{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# 4장: 모델 평균화를 통한 연합 학습\n",
    "\n",
    "**복습** 본 튜토리얼의 2장에서는 매우 간단한 버전의 연합 학습을 사용하여 모델을 학습했습니다. 이를 위해 모델 소유자는 각 데이터 소유자의 그래디언트(gradient)에 대한 정보를 알 수 있었고, 데이터 소유자들은 모델 소유자를 신뢰해야 했습니다.\n",
    "\n",
    "**설명:** 본 튜토리얼에서는 3장에서 소개했던 고급 집계 방법들을 응용해서, 최종 모델이 모델 소유자 (우리)에게 보내지기 전에 신뢰할 수 있는 \"보안 작업자\"가 모델의 가중치들을 집계하는 방법을 소개할 것입니다.\n",
    "\n",
    "이렇게 해서, 어느 가중치가 누구의 것인지 오직 보안 작업자만이 알 수 있습니다. 우리는 모델의 어느 부분이 바뀌었는지는 알 수 있어도, 어떤 작업자(Bob 또는 Alice)가 무엇을 바꿨는지는 알 수 없고, 이렇게 추가적인 사생활 보호가 가능합니다.\n",
    "\n",
    "저자:\n",
    " - Andrew Trask - Twitter: [@iamtrask](https://twitter.com/iamtrask)\n",
    " - Jason Mancuso - Twitter: [@jvmancuso](https://twitter.com/jvmancuso)\n",
    "\n",
    "역자:\n",
    " - Seungjae Ryan Lee - Twitter: [@seungjaeryanlee](https://twitter.com/seungjaeryanlee)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "import syft as sy\n",
    "import copy\n",
    "hook = sy.TorchHook(torch)\n",
    "from torch import nn, optim"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# 1단계: 데이터 소유자 만들기\n",
    "\n",
    "우선, 각각 다른 데이터를 가진 두 명의 데이터 소유자(Bob과 Alice)를 생성하겠습니다. 또, \"secure_worker\"이라는 보안 기계도 만들 겁니다. 실제로 이 역할은 (Intel의 SGX와 같은) 안전한 하드웨어나 단순히 모두가 신뢰할 수 있는 중재자가 할 수 있습니다."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# create a couple workers\n",
    "\n",
    "bob = sy.VirtualWorker(hook, id=\"bob\")\n",
    "alice = sy.VirtualWorker(hook, id=\"alice\")\n",
    "secure_worker = sy.VirtualWorker(hook, id=\"secure_worker\")\n",
    "\n",
    "\n",
    "# A Toy Dataset\n",
    "data = torch.tensor([[0,0],[0,1],[1,0],[1,1.]], requires_grad=True)\n",
    "target = torch.tensor([[0],[0],[1],[1.]], requires_grad=True)\n",
    "\n",
    "# get pointers to training data on each worker by\n",
    "# sending some training data to bob and alice\n",
    "bobs_data = data[0:2].send(bob)\n",
    "bobs_target = target[0:2].send(bob)\n",
    "\n",
    "alices_data = data[2:].send(alice)\n",
    "alices_target = target[2:].send(alice)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# 2단계: 모델 생성\n",
    "\n",
    "이 예제에서는 간단한 선형 모델을 학습할 것입니다. 이 모델은 PyTorch의 nn.Linear을 사용해서 생성할 수 있습니다."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Iniitalize A Toy Model\n",
    "model = nn.Linear(2,1)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# 3단계: 모델의 사본을 Alice 와 Bob에게 보내기\n",
    "\n",
    "다음으로, 현재 모델의 사본을 Alice와 Bob에게 보내서, 각각 자신의 데이터셋에서 학습을 수행할 수 있게 합니다."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "bobs_model = model.copy().send(bob)\n",
    "alices_model = model.copy().send(alice)\n",
    "\n",
    "bobs_opt = optim.SGD(params=bobs_model.parameters(),lr=0.1)\n",
    "alices_opt = optim.SGD(params=alices_model.parameters(),lr=0.1)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# 4단계: Bob과 Alice의 모델 (병렬)학습시키기\n",
    "\n",
    "안전한 평균화를 통한 연합 학습에서 늘 하는 것처럼, 각 데이터 소유자는 모델들의 평균을 내기 전에 먼저 여러 번 자신의 모델을 학습합니다."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "for i in range(10):\n",
    "\n",
    "    # Train Bob's Model\n",
    "    bobs_opt.zero_grad()\n",
    "    bobs_pred = bobs_model(bobs_data)\n",
    "    bobs_loss = ((bobs_pred - bobs_target)**2).sum()\n",
    "    bobs_loss.backward()\n",
    "\n",
    "    bobs_opt.step()\n",
    "    bobs_loss = bobs_loss.get().data\n",
    "\n",
    "    # Train Alice's Model\n",
    "    alices_opt.zero_grad()\n",
    "    alices_pred = alices_model(alices_data)\n",
    "    alices_loss = ((alices_pred - alices_target)**2).sum()\n",
    "    alices_loss.backward()\n",
    "\n",
    "    alices_opt.step()\n",
    "    alices_loss = alices_loss.get().data\n",
    "    \n",
    "    print(\"Bob:\" + str(bobs_loss) + \" Alice:\" + str(alices_loss))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# 5단계: 업데이트된 두 모델을 보안 작업자에게 전송하기\n",
    "\n",
    "이제 각 데이터 소유자는 부분적으로 훈련된 모델을 가지고 있으므로, 안전하게 이 모델들의 평균을 내야 합니다. 우리는 Alice와 Bob에게 그들의 모델을 안전한(신뢰할 수 있는) 서버로 보내라고 지시함으로써 이를 성취할 수 있습니다.\n",
    "\n",
    "우리 API의 사용은 각 모델이 secure_worker에게 **직접적으로** 전송된다는 것을 의미합니다. 우리는 결코 그것을 보지 못합니다."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "alices_model.move(secure_worker)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "bobs_model.move(secure_worker)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# 6단계: 모델 평균내기"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "이 학습 epoch의 마지막 단계는 Bob과 Alice가 훈련한 모델들을 평균화한 것을 이용해서 글로벌 \"모델\"의 값들을 설정하는 것입니다."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "with torch.no_grad():\n",
    "    model.weight.set_(((alices_model.weight.data + bobs_model.weight.data) / 2).get())\n",
    "    model.bias.set_(((alices_model.bias.data + bobs_model.bias.data) / 2).get())\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# 반복하기\n",
    "\n",
    "이제 우리는 이것을 여러 번 반복하면 됩니다!"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "iterations = 10\n",
    "worker_iters = 5\n",
    "\n",
    "for a_iter in range(iterations):\n",
    "    \n",
    "    bobs_model = model.copy().send(bob)\n",
    "    alices_model = model.copy().send(alice)\n",
    "\n",
    "    bobs_opt = optim.SGD(params=bobs_model.parameters(),lr=0.1)\n",
    "    alices_opt = optim.SGD(params=alices_model.parameters(),lr=0.1)\n",
    "\n",
    "    for wi in range(worker_iters):\n",
    "\n",
    "        # Train Bob's Model\n",
    "        bobs_opt.zero_grad()\n",
    "        bobs_pred = bobs_model(bobs_data)\n",
    "        bobs_loss = ((bobs_pred - bobs_target)**2).sum()\n",
    "        bobs_loss.backward()\n",
    "\n",
    "        bobs_opt.step()\n",
    "        bobs_loss = bobs_loss.get().data\n",
    "\n",
    "        # Train Alice's Model\n",
    "        alices_opt.zero_grad()\n",
    "        alices_pred = alices_model(alices_data)\n",
    "        alices_loss = ((alices_pred - alices_target)**2).sum()\n",
    "        alices_loss.backward()\n",
    "\n",
    "        alices_opt.step()\n",
    "        alices_loss = alices_loss.get().data\n",
    "    \n",
    "    alices_model.move(secure_worker)\n",
    "    bobs_model.move(secure_worker)\n",
    "    with torch.no_grad():\n",
    "        model.weight.set_(((alices_model.weight.data + bobs_model.weight.data) / 2).get())\n",
    "        model.bias.set_(((alices_model.bias.data + bobs_model.bias.data) / 2).get())\n",
    "    \n",
    "    print(\"Bob:\" + str(bobs_loss) + \" Alice:\" + str(alices_loss))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "마지막으로, 우리의 모델이 올바르게 학습되었는지 확인하기 위해, 테스트 데이터셋 (test dataset)으로 이 모델을 평가할 것입니다. 이 간단한 예제에서는 기존 데이터를 사용하지만, 실제로는 새로운 데이터를 사용하여 이 모델이 본 적 없는 데이터에도 얼마나 잘 일반화하는지 확인합니다."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "preds = model(data)\n",
    "loss = ((preds - target) ** 2).sum()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "print(preds)\n",
    "print(target)\n",
    "print(loss.data)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "이 간단한 예제에서, 평균화된 모델은 데이터 소유자가 학습시킨 모델에 비해서 underfit 합니다. 대신 우리는 각 데이터 소유자의 학습 데이터를 노출시키지 않고 이 모델을 훈련시킬 수 있었습니다. 또한 우리는 모델 소유자에게 데이터가 유출되는 것을 방지하기 위해 각 데이터 소유자의 업데이트된 모델을 신뢰할 수 있는 집계기에서 통합했습니다.\n",
    "\n",
    "향후의 튜토리얼에서는, 이러한 신뢰할 수 있는 집계 방식을 그래디언트에 사용하는 법을 배울 것입니다. 이렇게 함으로써 더 나은 그래디언트로 모델을 업데이트 함으로써 더 강력한 모델을 얻을 수 있습니다."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# 축하합니다!!! - 이제 커뮤니티에 가입할 시간입니다!\n",
    "\n",
    "이 파이썬 노트북 튜토리얼을 완료한 것을 축하합니다! 이 튜토리얼에 만족하셨다면, 그리고 사생활을 보호하는 분산소유 형태의 AI와 AI 공급망(데이터)을 위한 움직임에 동참하고 싶다면 다음과 같은 방법이 있습니다!\n",
    "\n",
    "### PySyft GitHub의 Star 하세요\n",
    "\n",
    "커뮤니티를 돕는 가장 쉬운 방법은 GitHub 리포지토리를 널리 알리는 것입니다! 이를 통해 우리가 만드는 멋진 도구에 대한 인식을 높일 수 있습니다.\n",
    "\n",
    "- [Star PySyft](https://github.com/OpenMined/PySyft)\n",
    "\n",
    "### 우리 Slack에 가입하세요!\n",
    "\n",
    "최신 소식을 받는 가장 좋은 방법은 커뮤니티에 가입하는 것입니다! [http://slack.openmined.org](http://slack.openmined.org)에서 양식을 작성하면 됩니다.\n",
    "\n",
    "### 프로젝트에 참여하십시오!\n",
    "\n",
    "커뮤니티에 기여하는 가장 좋은 방법은 코드에 기여하는 것입니다! 언제든지 PySyft GitHub 이슈 페이지로 이동하여 \"Projects\"를 필터링 할 수 있습니다. 그러면 참여할 수 있는 프로젝트에 대한 개요를 제공하는 모든 최상위 이슈 티켓이 표시됩니다. 프로젝트에 참여하고 싶지는 않지만 약간의 코딩을 원한다면 \"good first issue\"로 표시된 GitHub 이슈를 검색하여 더 많은 \"일회용\" 미니 프로젝트를 찾을 수도 있습니다.\n",
    "\n",
    "- [PySyft Projects](https://github.com/OpenMined/PySyft/issues?q=is%3Aopen+is%3Aissue+label%3AProject)\n",
    "- [Good First Issue Tickets](https://github.com/OpenMined/PySyft/issues?q=is%3Aopen+is%3Aissue+label%3A%22good+first+issue%22)\n",
    "\n",
    "### 기부하기\n",
    "\n",
    "코드에 기여할 시간이 없지만 지원을 하고 싶다면 Open Collective의 후원자가 될 수도 있습니다. 모든 기부금은 웹 호스팅 및 해커톤 및 밋업과 같은 기타 커뮤니티 비용으로 들어갑니다!\n",
    "\n",
    "[OpenMined's Open Collective Page](https://opencollective.com/openmined)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.4"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
